{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sys, os, _pickle as pickle\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import nltk\n",
    "from sklearn.metrics import f1_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "data_dir = '../data'                        # Directory for Data and Other files\n",
    "ckpt_dir = '../checkpoint'                  # Directory for Checkpoints \n",
    "word_embd_dir = '../checkpoint/word_embd'   # Directory for Checkpoints of Word Embedding Layer\n",
    "model_dir = '../checkpoint/modelv1'         # Directory for Checkpoints of Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "word_embd_dim = 100       # Dimension of embedding layer for words\n",
    "pos_embd_dim = 25         # Dimension of embedding layer for POS Tags\n",
    "dep_embd_dim = 25         # Dimension of embedding layer for Dependency Types\n",
    "\n",
    "word_vocab_size = 400001  # Vocab size for Words\n",
    "pos_vocab_size = 10       # Vocab size for POS Tags\n",
    "dep_vocab_size = 21       # Vocab size for Dependency Types\n",
    "\n",
    "relation_classes = 19     # No. of Relation Classes\n",
    "state_size = 100          # Dimension of States of LSTM-RNNs\n",
    "batch_size = 10           # Batch Size for training\n",
    "\n",
    "channels = 3              # No. of types of features to feed in LSTM-RNN\n",
    "lambda_l2 = 0.0001\n",
    "max_len_path = 10         # Maximum length of sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with tf.name_scope(\"input\"):\n",
    "    \n",
    "    # Length of the sequence\n",
    "    path_length = tf.placeholder(tf.int32, shape=[2, batch_size], name=\"path1_length\") \n",
    "    \n",
    "    # Words in the sequence\n",
    "    word_ids = tf.placeholder(tf.int32, shape=[2, batch_size, max_len_path], name=\"word_ids\") \n",
    "    \n",
    "     # POS Tags in the sequence\n",
    "    pos_ids = tf.placeholder(tf.int32, [2, batch_size, max_len_path], name=\"pos_ids\") \n",
    "    \n",
    "    # Dependency Types in the sequence\n",
    "    dep_ids = tf.placeholder(tf.int32, [2, batch_size, max_len_path], name=\"dep_ids\") \n",
    "    \n",
    "     # True Relation btw the entities\n",
    "    y = tf.placeholder(tf.int32, [batch_size], name=\"y\")   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Embedding Layer of Words \n",
    "with tf.name_scope(\"word_embedding\"):\n",
    "    W = tf.Variable(tf.constant(0.0, shape=[word_vocab_size, word_embd_dim]), name=\"W\")\n",
    "    embedding_placeholder = tf.placeholder(tf.float32,[word_vocab_size, word_embd_dim])\n",
    "    embedding_init = W.assign(embedding_placeholder)\n",
    "    embedded_word = tf.nn.embedding_lookup(W, word_ids)\n",
    "    word_embedding_saver = tf.train.Saver({\"word_embedding/W\": W})\n",
    "\n",
    "# Embedding Layer of POS Tags \n",
    "with tf.name_scope(\"pos_embedding\"):\n",
    "    W = tf.Variable(tf.random_uniform([pos_vocab_size, pos_embd_dim]), name=\"W\")\n",
    "    embedded_pos = tf.nn.embedding_lookup(W, pos_ids)\n",
    "    pos_embedding_saver = tf.train.Saver({\"pos_embedding/W\": W})\n",
    "\n",
    "# Embedding Layer of Dependency Types \n",
    "with tf.name_scope(\"dep_embedding\"):\n",
    "    W = tf.Variable(tf.random_uniform([dep_vocab_size, dep_embd_dim]), name=\"W\")\n",
    "    embedded_dep = tf.nn.embedding_lookup(W, dep_ids)\n",
    "    dep_embedding_saver = tf.train.Saver({\"dep_embedding/W\": W})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "hidden_states = tf.zeros([channels, batch_size, state_size], name=\"hidden_state\")\n",
    "cell_states = tf.zeros([channels, batch_size, state_size], name=\"cell_state\")\n",
    "\n",
    "init_states = [tf.contrib.rnn.LSTMStateTuple(hidden_states[i], cell_states[i]) for i in range(channels)]\n",
    "\n",
    "with tf.variable_scope(\"word_lstm1\"):\n",
    "    cell = tf.contrib.rnn.BasicLSTMCell(word_state_size)\n",
    "    state_series, current_state = tf.nn.dynamic_rnn(cell, embedded_word[0], sequence_length=path_length[0], initial_state=init_states[0])\n",
    "    state_series_word1 = tf.reduce_max(state_series, axis=1)\n",
    "\n",
    "with tf.variable_scope(\"word_lstm2\"):\n",
    "    cell = tf.contrib.rnn.BasicLSTMCell(word_state_size)\n",
    "    state_series, current_state = tf.nn.dynamic_rnn(cell, embedded_word[1], sequence_length=path_length[1], initial_state=init_states[0])\n",
    "    state_series_word2 = tf.reduce_max(state_series, axis=1)\n",
    "\n",
    "with tf.variable_scope(\"pos_lstm1\"):\n",
    "    cell = tf.contrib.rnn.BasicLSTMCell(other_state_size)\n",
    "    state_series, current_state = tf.nn.dynamic_rnn(cell, embedded_pos[0], sequence_length=path_length[0],initial_state=init_states[1])\n",
    "    state_series_pos1 = tf.reduce_max(state_series, axis=1)\n",
    "\n",
    "with tf.variable_scope(\"pos_lstm2\"):\n",
    "    cell = tf.contrib.rnn.BasicLSTMCell(other_state_size)\n",
    "    state_series, current_state = tf.nn.dynamic_rnn(cell, embedded_pos[1], sequence_length=path_length[1],initial_state=init_states[1])\n",
    "    state_series_pos2 = tf.reduce_max(state_series, axis=1)\n",
    "\n",
    "with tf.variable_scope(\"dep_lstm1\"):\n",
    "    cell = tf.contrib.rnn.BasicLSTMCell(other_state_size)\n",
    "    state_series, current_state = tf.nn.dynamic_rnn(cell, embedded_dep[0], sequence_length=path_length[0], initial_state=init_states[2])\n",
    "    state_series_dep1 = tf.reduce_max(state_series, axis=1)\n",
    "\n",
    "with tf.variable_scope(\"dep_lstm2\"):\n",
    "    cell = tf.contrib.rnn.BasicLSTMCell(other_state_size)\n",
    "    state_series, current_state = tf.nn.dynamic_rnn(cell, embedded_dep[1], sequence_length=path_length[1], initial_state=init_states[2])\n",
    "    state_series_dep2 = tf.reduce_max(state_series, axis=1)\n",
    "\n",
    "state_series1 = tf.concat([state_series_word1, state_series_pos1, state_series_dep1], 1)\n",
    "state_series2 = tf.concat([state_series_word2, state_series_pos2, state_series_dep2], 1)\n",
    "\n",
    "state_series = tf.concat([state_series1, state_series2], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "with tf.name_scope(\"hidden_layer\"):\n",
    "    W = tf.Variable(tf.truncated_normal([600, 100], -0.1, 0.1), name=\"W\")\n",
    "    b = tf.Variable(tf.zeros([100]), name=\"b\")\n",
    "    y_hidden_layer = tf.matmul(state_series, W) + b\n",
    "\n",
    "with tf.name_scope(\"softmax_layer\"):\n",
    "    W = tf.Variable(tf.truncated_normal([100, relation_classes], -0.1, 0.1), name=\"W\")\n",
    "    b = tf.Variable(tf.zeros([relation_classes]), name=\"b\")\n",
    "    logits = tf.matmul(y_hidden_layer, W) + b\n",
    "    predictions = tf.argmax(logits, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "tv_all = tf.trainable_variables()\n",
    "tv_regu = []\n",
    "non_reg = [\"word_embedding/W:0\",\"pos_embedding/W:0\",'dep_embedding/W:0',\"global_step:0\",'hidden_layer/b:0','softmax_layer/b:0']\n",
    "for t in tv_all:\n",
    "    if t.name not in non_reg:\n",
    "        if(t.name.find('biases')==-1):\n",
    "            tv_regu.append(t)\n",
    "\n",
    "with tf.name_scope(\"loss\"):\n",
    "    l2_loss = lambda_l2 * tf.reduce_sum([ tf.nn.l2_loss(v) for v in tv_regu ])\n",
    "    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y))\n",
    "    total_loss = loss + l2_loss\n",
    "\n",
    "global_step = tf.Variable(0, name=\"global_step\")\n",
    "\n",
    "optimizer = tf.train.AdamOptimizer(0.001).minimize(total_loss, global_step=global_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "f = open(data_dir + '/vocab.pkl', 'rb')\n",
    "vocab = pickle.load(f)\n",
    "f.close()\n",
    "\n",
    "word2id = dict((w, i) for i,w in enumerate(vocab))\n",
    "id2word = dict((i, w) for i,w in enumerate(vocab))\n",
    "\n",
    "unknown_token = \"UNKNOWN_TOKEN\"\n",
    "word2id[unknown_token] = word_vocab_size -1\n",
    "id2word[word_vocab_size-1] = unknown_token\n",
    "\n",
    "pos_tags_vocab = []\n",
    "for line in open(data_dir + '/pos_tags.txt'):\n",
    "        pos_tags_vocab.append(line.strip())\n",
    "\n",
    "dep_vocab = []\n",
    "for line in open(data_dir + '/dependency_types.txt'):\n",
    "    dep_vocab.append(line.strip())\n",
    "\n",
    "relation_vocab = []\n",
    "for line in open(data_dir + '/relation_types.txt'):\n",
    "    relation_vocab.append(line.strip())\n",
    "\n",
    "\n",
    "rel2id = dict((w, i) for i,w in enumerate(relation_vocab))\n",
    "id2rel = dict((i, w) for i,w in enumerate(relation_vocab))\n",
    "\n",
    "pos_tag2id = dict((w, i) for i,w in enumerate(pos_tags_vocab))\n",
    "id2pos_tag = dict((i, w) for i,w in enumerate(pos_tags_vocab))\n",
    "\n",
    "dep2id = dict((w, i) for i,w in enumerate(dep_vocab))\n",
    "id2dep = dict((i, w) for i,w in enumerate(dep_vocab))\n",
    "\n",
    "pos_tag2id['OTH'] = 9\n",
    "id2pos_tag[9] = 'OTH'\n",
    "\n",
    "dep2id['OTH'] = 20\n",
    "id2dep[20] = 'OTH'\n",
    "\n",
    "JJ_pos_tags = ['JJ', 'JJR', 'JJS']\n",
    "NN_pos_tags = ['NN', 'NNS', 'NNP', 'NNPS']\n",
    "RB_pos_tags = ['RB', 'RBR', 'RBS']\n",
    "PRP_pos_tags = ['PRP', 'PRP$']\n",
    "VB_pos_tags = ['VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ']\n",
    "_pos_tags = ['CC', 'CD', 'DT', 'IN']\n",
    "\n",
    "def pos_tag(x):\n",
    "    if x in JJ_pos_tags:\n",
    "        return pos_tag2id['JJ']\n",
    "    if x in NN_pos_tags:\n",
    "        return pos_tag2id['NN']\n",
    "    if x in RB_pos_tags:\n",
    "        return pos_tag2id['RB']\n",
    "    if x in PRP_pos_tags:\n",
    "        return pos_tag2id['PRP']\n",
    "    if x in VB_pos_tags:\n",
    "        return pos_tag2id['VB']\n",
    "    if x in _pos_tags:\n",
    "        return pos_tag2id[x]\n",
    "    else:\n",
    "        return 9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "f = open(data_dir + '/train_paths', 'rb')\n",
    "word_p1, word_p2, dep_p1, dep_p2, pos_p1, pos_p2 = pickle.load(f)\n",
    "f.close()\n",
    "\n",
    "relations = []\n",
    "for line in open(data_dir + '/train_relations.txt'):\n",
    "    relations.append(line.strip().split()[1])\n",
    "\n",
    "length = len(word_p1)\n",
    "num_batches = int(length/batch_size)\n",
    "\n",
    "for i in range(length):\n",
    "    for j, word in enumerate(word_p1[i]):\n",
    "        word = word.lower()\n",
    "        word_p1[i][j] = word if word in word2id else unknown_token \n",
    "    for k, word in enumerate(word_p2[i]):\n",
    "        word = word.lower()\n",
    "        word_p2[i][k] = word if word in word2id else unknown_token \n",
    "    for l, d in enumerate(dep_p1[i]):\n",
    "        dep_p1[i][l] = d if d in dep2id else 'OTH'\n",
    "    for m, d in enumerate(dep_p2[i]):\n",
    "        dep_p2[i][m] = d if d in dep2id else 'OTH'\n",
    "\n",
    "word_p1_ids = np.ones([length, max_len_path],dtype=int)\n",
    "word_p2_ids = np.ones([length, max_len_path],dtype=int)\n",
    "pos_p1_ids = np.ones([length, max_len_path],dtype=int)\n",
    "pos_p2_ids = np.ones([length, max_len_path],dtype=int)\n",
    "dep_p1_ids = np.ones([length, max_len_path],dtype=int)\n",
    "dep_p2_ids = np.ones([length, max_len_path],dtype=int)\n",
    "rel_ids = np.array([rel2id[rel] for rel in relations])\n",
    "path1_len = np.array([len(w) for w in word_p1], dtype=int)\n",
    "path2_len = np.array([len(w) for w in word_p2])\n",
    "\n",
    "for i in range(length):\n",
    "    for j, w in enumerate(word_p1[i]):\n",
    "        word_p1_ids[i][j] = word2id[w]\n",
    "    for j, w in enumerate(word_p2[i]):\n",
    "        word_p2_ids[i][j] = word2id[w]\n",
    "    for j, w in enumerate(pos_p1[i]):\n",
    "        pos_p1_ids[i][j] = pos_tag(w)\n",
    "    for j, w in enumerate(pos_p2[i]):\n",
    "        pos_p2_ids[i][j] = pos_tag(w)\n",
    "    for j, w in enumerate(dep_p1[i]):\n",
    "        dep_p1_ids[i][j] = dep2id[w]\n",
    "    for j, w in enumerate(dep_p2[i]):\n",
    "        dep_p2_ids[i][j] = dep2id[w]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sess = tf.Session()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "saver = tf.train.Saver()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# f = open('data/word_embedding', 'rb')\n",
    "# word_embedding = pickle.load(f)\n",
    "# f.close()\n",
    "\n",
    "# sess.run(embedding_init, feed_dict={embedding_placeholder:word_embedding})\n",
    "# word_embedding_saver.save(sess, word_embd_dir + '/word_embd')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from checkpoint/modelv1/model\n"
     ]
    }
   ],
   "source": [
    "model = tf.train.latest_checkpoint(model_dir)\n",
    "saver.restore(sess, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# latest_embd = tf.train.latest_checkpoint(word_embd_dir)\n",
    "# word_embedding_saver.restore(sess, latest_embd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": true,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "num_epochs = 10\n",
    "for i in range(num_epochs):\n",
    "    for j in range(num_batches):\n",
    "        path_dict = [path1_len[j*batch_size:(j+1)*batch_size], path2_len[j*batch_size:(j+1)*batch_size]]\n",
    "        word_dict = [word_p1_ids[j*batch_size:(j+1)*batch_size], word_p2_ids[j*batch_size:(j+1)*batch_size]]\n",
    "        pos_dict = [pos_p1_ids[j*batch_size:(j+1)*batch_size], pos_p2_ids[j*batch_size:(j+1)*batch_size]]\n",
    "        dep_dict = [dep_p1_ids[j*batch_size:(j+1)*batch_size], dep_p2_ids[j*batch_size:(j+1)*batch_size]]\n",
    "        y_dict = rel_ids[j*batch_size:(j+1)*batch_size]\n",
    "        \n",
    "        feed_dict = {\n",
    "            path_length:path_dict,\n",
    "            word_ids:word_dict,\n",
    "            pos_ids:pos_dict,\n",
    "            dep_ids:dep_dict,\n",
    "            y:y_dict}\n",
    "        _, loss, step = sess.run([optimizer, total_loss, global_step], feed_dict)\n",
    "        if step%10==0:\n",
    "            print(\"Step:\", step, \"loss:\",loss)\n",
    "        if step % 1000 == 0:\n",
    "            saver.save(sess, model_dir + '/model')\n",
    "            print(\"Saved Model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training accuracy 99.2625\n"
     ]
    }
   ],
   "source": [
    "# training accuracy\n",
    "all_predictions = []\n",
    "for j in range(num_batches):\n",
    "    path_dict = [path1_len[j*batch_size:(j+1)*batch_size], path2_len[j*batch_size:(j+1)*batch_size]]\n",
    "    word_dict = [word_p1_ids[j*batch_size:(j+1)*batch_size], word_p2_ids[j*batch_size:(j+1)*batch_size]]\n",
    "    pos_dict = [pos_p1_ids[j*batch_size:(j+1)*batch_size], pos_p2_ids[j*batch_size:(j+1)*batch_size]]\n",
    "    dep_dict = [dep_p1_ids[j*batch_size:(j+1)*batch_size], dep_p2_ids[j*batch_size:(j+1)*batch_size]]\n",
    "    y_dict = rel_ids[j*batch_size:(j+1)*batch_size]\n",
    "\n",
    "    feed_dict = {\n",
    "        path_length:path_dict,\n",
    "        word_ids:word_dict,\n",
    "        pos_ids:pos_dict,\n",
    "        dep_ids:dep_dict,\n",
    "        y:y_dict}\n",
    "    batch_predictions = sess.run(predictions, feed_dict)\n",
    "    all_predictions.append(batch_predictions)\n",
    "\n",
    "y_pred = []\n",
    "for i in range(num_batches):\n",
    "    for pred in all_predictions[i]:\n",
    "        y_pred.append(pred)\n",
    "\n",
    "count = 0\n",
    "for i in range(batch_size*num_batches):\n",
    "    count += y_pred[i]==rel_ids[i]\n",
    "accuracy = count/(batch_size*num_batches) * 100\n",
    "\n",
    "print(\"training accuracy\", accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "f = open(data_dir + '/test_paths', 'rb')\n",
    "word_p1, word_p2, dep_p1, dep_p2, pos_p1, pos_p2 = pickle.load(f)\n",
    "f.close()\n",
    "\n",
    "relations = []\n",
    "for line in open(data_dir + '/test_relations.txt'):\n",
    "    relations.append(line.strip().split()[0])\n",
    "\n",
    "length = len(word_p1)\n",
    "num_batches = int(length/batch_size)\n",
    "\n",
    "for i in range(length):\n",
    "    for j, word in enumerate(word_p1[i]):\n",
    "        word = word.lower()\n",
    "        word_p1[i][j] = word if word in word2id else unknown_token \n",
    "    for k, word in enumerate(word_p2[i]):\n",
    "        word = word.lower()\n",
    "        word_p2[i][k] = word if word in word2id else unknown_token \n",
    "    for l, d in enumerate(dep_p1[i]):\n",
    "        dep_p1[i][l] = d if d in dep2id else 'OTH'\n",
    "    for m, d in enumerate(dep_p2[i]):\n",
    "        dep_p2[i][m] = d if d in dep2id else 'OTH'\n",
    "\n",
    "word_p1_ids = np.ones([length, max_len_path],dtype=int)\n",
    "word_p2_ids = np.ones([length, max_len_path],dtype=int)\n",
    "pos_p1_ids = np.ones([length, max_len_path],dtype=int)\n",
    "pos_p2_ids = np.ones([length, max_len_path],dtype=int)\n",
    "dep_p1_ids = np.ones([length, max_len_path],dtype=int)\n",
    "dep_p2_ids = np.ones([length, max_len_path],dtype=int)\n",
    "rel_ids = np.array([rel2id[rel] for rel in relations])\n",
    "path1_len = np.array([len(w) for w in word_p1], dtype=int)\n",
    "path2_len = np.array([len(w) for w in word_p2])\n",
    "\n",
    "for i in range(length):\n",
    "    for j, w in enumerate(word_p1[i]):\n",
    "        word_p1_ids[i][j] = word2id[w]\n",
    "    for j, w in enumerate(word_p2[i]):\n",
    "        word_p2_ids[i][j] = word2id[w]\n",
    "    for j, w in enumerate(pos_p1[i]):\n",
    "        pos_p1_ids[i][j] = pos_tag(w)\n",
    "    for j, w in enumerate(pos_p2[i]):\n",
    "        pos_p2_ids[i][j] = pos_tag(w)\n",
    "    for j, w in enumerate(dep_p1[i]):\n",
    "        dep_p1_ids[i][j] = dep2id[w]\n",
    "    for j, w in enumerate(dep_p2[i]):\n",
    "        dep_p2_ids[i][j] = dep2id[w]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test accuracy 61.4022140221\n"
     ]
    }
   ],
   "source": [
    "# test \n",
    "all_predictions = []\n",
    "for j in range(num_batches):\n",
    "    path_dict = [path1_len[j*batch_size:(j+1)*batch_size], path2_len[j*batch_size:(j+1)*batch_size]]\n",
    "    word_dict = [word_p1_ids[j*batch_size:(j+1)*batch_size], word_p2_ids[j*batch_size:(j+1)*batch_size]]\n",
    "    pos_dict = [pos_p1_ids[j*batch_size:(j+1)*batch_size], pos_p2_ids[j*batch_size:(j+1)*batch_size]]\n",
    "    dep_dict = [dep_p1_ids[j*batch_size:(j+1)*batch_size], dep_p2_ids[j*batch_size:(j+1)*batch_size]]\n",
    "    y_dict = rel_ids[j*batch_size:(j+1)*batch_size]\n",
    "\n",
    "    feed_dict = {\n",
    "        path_length:path_dict,\n",
    "        word_ids:word_dict,\n",
    "        pos_ids:pos_dict,\n",
    "        dep_ids:dep_dict,\n",
    "        y:y_dict}\n",
    "    batch_predictions = sess.run(predictions, feed_dict)\n",
    "    all_predictions.append(batch_predictions)\n",
    "\n",
    "y_pred = []\n",
    "for i in range(num_batches):\n",
    "    for pred in all_predictions[i]:\n",
    "        y_pred.append(pred)\n",
    "\n",
    "count = 0\n",
    "for i in range(batch_size*num_batches):\n",
    "    count += y_pred[i]==rel_ids[i]\n",
    "accuracy = count/(batch_size*num_batches) * 100\n",
    "\n",
    "print(\"test accuracy\", accuracy)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
